{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## IN THIS NOTEBOOK, WE WILL EXPLORE THE COLBERT AS A RERANKER AND RETRIEVER IN LOCAL MODE. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* If you want to build a server from your colbert local index, please refer [here](https://github.com/stanford-futuredata/ColBERT/blob/main/server.py)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from colbert.infra.config import ColBERTConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# You can set this environment variable for debugging purposes\n",
    "os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = \"True\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's review the colbert config class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can view the different attributes of the colbert config by uncommenting cell below\n",
    "# for k,v in ColBERTConfig().__dict__.items():\n",
    "#     print(f\"{k} --> {v}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "passages =  [\"It's a piece of cake.\", \"Don't put off until tomorrow what you can do today.\", 'To kill two birds with one stone.', 'Actions speak louder than words.', 'Honesty is the best policy.', 'If you want something done right, do it yourself.', 'The best things in life are free.', \"Don't count your chickens before they hatch.\", 'She sells seashells by the seashore.', 'Practice makes perfect.', \"Where there's a will, there's a way.\", 'Absence makes the heart grow fonder.', 'When the going gets tough, the tough get going.', 'A journey of a thousand miles begins with a single step.', \"You can't have your cake and eat it too.\", \"If you can't beat them, join them.\", 'Keep your friends close and your enemies closer.', \"Don't put all your eggs in one basket.\", \"All's fair in love and war.\", 'Every dog has its day.', 'All good things must come to an end.', 'Once bitten, twice shy.', \"The apple doesn't fall far from the tree.\", 'A penny saved is a penny earned.', \"Don't bite the hand that feeds you.\", 'You reap what you sow.', 'An apple a day keeps the doctor away.', \"One man's trash is another man's treasure.\", 'The squeaky wheel gets the grease.', 'A picture is worth a thousand words.', 'Fortune favors the bold.', 'Practice what you preach.', 'A watched pot never boils.', 'No pain, no gain.', \"You can't make an omelet without breaking eggs.\", \"There's no place like home.\", 'Ask and you shall receive.', 'Let sleeping dogs lie.', 'If the shoe fits, wear it.', 'Every cloud has a silver lining.', 'Look before you leap.', 'The more, the merrier.', 'The grass is always greener on the other side.', 'Beauty is only skin deep.', \"Two wrongs don't make a right.\", 'Beauty is in the eye of the beholder.', 'Necessity is the mother of invention.', 'Out of sight, out of mind.', 'Patience is a virtue.', 'Curiosity killed the cat.', \"If at first you don't succeed, try, try again.\", \"Beggars can't be choosers.\", 'Too many cooks spoil the broth.', 
'Easy come, easy go.', \"Don't cry over spilled milk.\", \"There's no such thing as a free lunch.\", 'A bird in the hand is worth two in the bush.', 'Good things come to those who wait.', 'The quick brown fox jumps over the lazy dog.', 'It takes two to tango.', 'A friend in need is a friend indeed.', 'Like father, like son.', 'Let bygones be bygones.', 'Kill two birds with one stone.', 'A penny for your thoughts.', 'I am the master of my fate, I am the captain of my soul.', 'The pen is mightier than the sword.', 'When in Rome, do as the Romans do.', \"Rome wasn't built in a day.\", \"You can't judge a book by its cover.\", \"It's raining cats and dogs.\", 'Make hay while the sun shines.', \"It's better to be safe than sorry.\", 'The early bird catches the worm.', 'To be or not to be, that is the question.', 'Better late than never.']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## This tutorial is running from the `examples/integrations/tutorials folder`, hence we need to add the system path for dspy\n",
    "\n",
    "* If you have installed the dspy package, then you don't need to run the below cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../..\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## COLBERT AS RETRIEVER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dspy\n",
    "colbert_config = ColBERTConfig()\n",
    "colbert_config.index_name = \"Colbert-RM\"\n",
    "colbert_config.experiment = \"Colbert-Experiment\"\n",
    "colbert_config.checkpoint = \"colbert-ir/colbertv2.0\"\n",
    "colbert_retriever = dspy.ColBERTv2RetrieverLocal(\n",
    "    passages = passages,load_only=False,\n",
    "    colbert_config=colbert_config\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#CONFIGURE COLBERT IN DSPY\n",
    "dspy.settings.configure(rm=colbert_retriever)\n",
    "\n",
    "retrieved_docs = dspy.Retrieve(k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DeprecationWarning: 'dspy.Retrieve' for reranking has been deprecated, please use dspy.RetrieveThenRerank. The reranking is ignored here. In the future this will raise an error.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==\n",
      "#> Input: . What is the meaning of life?, \t\t True, \t\t None\n",
      "#> Output IDs: torch.Size([32]), tensor([ 101,    1, 2054, 2003, 1996, 3574, 1997, 2166, 1029,  102,  103,  103,\n",
      "         103,  103,  103,  103,  103,  103,  103,  103,  103,  103,  103,  103,\n",
      "         103,  103,  103,  103,  103,  103,  103,  103], device='cuda:0')\n",
      "#> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pred = retrieved_docs(\n",
    "    \"What is the meaning of life?\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Prediction(\n",
       "    score=[nan, nan, nan, nan, nan],\n",
       "    pid=[33, 6, 47, 74, 48],\n",
       "    passages=['No pain, no gain.', 'The best things in life are free.', 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Patience is a virtue.']\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DeprecationWarning: 'dspy.Retrieve' for reranking has been deprecated, please use dspy.RetrieveThenRerank. The reranking is ignored here. In the future this will raise an error.\n"
     ]
    }
   ],
   "source": [
    "multiple_pred = retrieved_docs(\n",
    "    [\"What is the meaning of life?\",\"Meaning of pain?\"],by_prob=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Prediction(\n",
       "     score=[nan, nan, nan, nan, nan],\n",
       "     pid=[33, 6, 47, 74, 48],\n",
       "     passages=['No pain, no gain.', 'The best things in life are free.', 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Patience is a virtue.']\n",
       " ),\n",
       " Prediction(\n",
       "     score=[nan, nan, nan, nan, nan],\n",
       "     pid=[16, 0, 47, 74, 26],\n",
       "     passages=['Keep your friends close and your enemies closer.', \"It's a piece of cake.\", 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'An apple a day keeps the doctor away.']\n",
       " )]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "multiple_pred"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## COLBERT AS RERANKER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "colbert_config = ColBERTConfig()\n",
    "colbert_config.index_name = 'colbert-ir-index'\n",
    "colbert_reranker = dspy.ColBERTv2RerankerLocal(\n",
    "    checkpoint='colbert-ir/colbertv2.0',colbert_config=colbert_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "dspy.settings.configure(rm=colbert_retriever,reranker=colbert_reranker)\n",
    "\n",
    "retrieve_rerank = dspy.RetrieveThenRerank(k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred = retrieve_rerank(\n",
    "    [\"What is the meaning of life?\",\"Meaning of pain?\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Prediction(\n",
       "     score=[nan, nan, nan, nan, nan],\n",
       "     pid=[6, 48, 74, 47, 33],\n",
       "     rerank_score=[15.8359375, 14.2109375, 12.5703125, 11.7890625, 9.1796875],\n",
       "     passages=['The best things in life are free.', 'Patience is a virtue.', 'To be or not to be, that is the question.', 'Out of sight, out of mind.', 'No pain, no gain.']\n",
       " ),\n",
       " Prediction(\n",
       "     score=[nan, nan, nan, nan, nan],\n",
       "     pid=[33, 0, 47, 74, 16],\n",
       "     rerank_score=[19.828125, 12.2890625, 11.171875, 9.09375, 6.8984375],\n",
       "     passages=['No pain, no gain.', \"It's a piece of cake.\", 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Keep your friends close and your enemies closer.']\n",
       " )]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred"
   ]
  },
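  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `rerank_score` values above are easier to read with the ColBERT scoring rule in mind. ColBERT-style models score a passage by late interaction: each query token embedding is matched to its most similar passage token embedding, and those maxima are summed (often called MaxSim). The cell below is a minimal NumPy sketch of that rule on random unit vectors. It is not the code that DSPy or ColBERT runs; the names `maxsim_score` and `random_unit_vectors` and the toy dimensions are assumptions made for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def maxsim_score(query_emb, passage_emb):\n",
    "    # Hypothetical helper: query_emb is (num_query_tokens, dim), passage_emb is (num_passage_tokens, dim)\n",
    "    sim = query_emb @ passage_emb.T      # token-to-token similarities\n",
    "    return float(sim.max(axis=1).sum())  # best passage match per query token, summed\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "def random_unit_vectors(n, dim=8):\n",
    "    # Random stand-ins for token embeddings, L2-normalized so similarities are cosines\n",
    "    v = rng.normal(size=(n, dim))\n",
    "    return v / np.linalg.norm(v, axis=1, keepdims=True)\n",
    "\n",
    "query = random_unit_vectors(4)\n",
    "passage_a = random_unit_vectors(6)\n",
    "# passage_b also contains every query token, so each query token finds an exact match\n",
    "passage_b = np.vstack([passage_a, query])\n",
    "\n",
    "# With unit vectors the score is bounded by the number of query tokens (4 here)\n",
    "print(maxsim_score(query, passage_a), maxsim_score(query, passage_b))"
   ]
  },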
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## YOU CAN ALSO COLBERT RERANKER AS STANDALONE MODEL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install tabulate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tabulate\n",
    "\n",
    "scores_arr = colbert_reranker(\n",
    "    \"What is the meaning of life and pain?\",\n",
    "    # Pass a subset of passages\n",
    "    passages[:10]\n",
    ")\n",
    "\n",
    "tabulate_data = []\n",
    "for idx in np.argsort(scores_arr)[::-1]:\n",
    "    # print(f\"Passage = {passages[idx]} --> Score = {scores_arr[idx]}\")\n",
    "    tabulate_data.append([passages[idx],scores_arr[idx]])\n",
    "\n",
    "table = tabulate.tabulate(tabulate_data,tablefmt=\"html\",headers={'sentence','score'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>\n",
       "<thead>\n",
       "<tr><th>score                                              </th><th style=\"text-align: right;\">  sentence</th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>The best things in life are free.                  </td><td style=\"text-align: right;\">  12.5156 </td></tr>\n",
       "<tr><td>It&#x27;s a piece of cake.                              </td><td style=\"text-align: right;\">  10      </td></tr>\n",
       "<tr><td>Practice makes perfect.                            </td><td style=\"text-align: right;\">   8.27344</td></tr>\n",
       "<tr><td>Honesty is the best policy.                        </td><td style=\"text-align: right;\">   7.57422</td></tr>\n",
       "<tr><td>To kill two birds with one stone.                  </td><td style=\"text-align: right;\">   7.51953</td></tr>\n",
       "<tr><td>Actions speak louder than words.                   </td><td style=\"text-align: right;\">   7.05469</td></tr>\n",
       "<tr><td>If you want something done right, do it yourself.  </td><td style=\"text-align: right;\">   6.52344</td></tr>\n",
       "<tr><td>Don&#x27;t put off until tomorrow what you can do today.</td><td style=\"text-align: right;\">   3.78711</td></tr>\n",
       "<tr><td>She sells seashells by the seashore.               </td><td style=\"text-align: right;\">   2.77148</td></tr>\n",
       "<tr><td>Don&#x27;t count your chickens before they hatch.       </td><td style=\"text-align: right;\">   1.82227</td></tr>\n",
       "</tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from IPython.display import HTML, display\n",
    "display(HTML(table))"
   ]
  },
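  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ranking and tabulation steps can also be tried without loading a model. The cell below is a minimal sketch on made-up data: `demo_passages` and `demo_scores` are hypothetical values, not reranker output. It shows why `np.argsort` is reversed (it sorts in ascending order) and why the headers are given as a list, which keeps the column labels in a fixed order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tabulate\n",
    "\n",
    "# Hypothetical passages and scores, for illustration only\n",
    "demo_passages = ['first passage', 'second passage', 'third passage']\n",
    "demo_scores = np.array([2.5, 9.0, 4.25])\n",
    "\n",
    "# argsort is ascending, so reverse it to rank from highest to lowest score\n",
    "order = np.argsort(demo_scores)[::-1]\n",
    "rows = [[demo_passages[i], float(demo_scores[i])] for i in order]\n",
    "\n",
    "# A list keeps the header order fixed, so labels stay aligned with their columns\n",
    "print(tabulate.tabulate(rows, headers=['sentence', 'score']))"
   ]
  },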
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
